# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testAffineExprCapsule
@run
def testAffineExprCapsule():
    with Context() as ctx:
        affine_expr = AffineExpr.get_constant(42)

    affine_expr_capsule = affine_expr._CAPIPtr
    # CHECK: capsule object
    # CHECK: mlir.ir.AffineExpr._CAPIPtr
    print(affine_expr_capsule)

    affine_expr_2 = AffineExpr._CAPICreate(affine_expr_capsule)
    assert affine_expr == affine_expr_2
    assert affine_expr_2.context == ctx


# CHECK-LABEL: TEST: testAffineExprEq
@run
def testAffineExprEq():
    with Context():
        a1 = AffineExpr.get_constant(42)
        a2 = AffineExpr.get_constant(42)
        a3 = AffineExpr.get_constant(43)
        # CHECK: True
        print(a1 == a1)
        # CHECK: True
        print(a1 == a2)
        # CHECK: False
        print(a1 == a3)
        # CHECK: False
        print(a1 is None)
        # CHECK: False
        print(a1 == "foo")


# CHECK-LABEL: TEST: testAffineExprContext
@run
def testAffineExprContext():
    with Context():
        a1 = AffineExpr.get_constant(42)
    with Context():
        a2 = AffineExpr.get_constant(42)

    # CHECK: False
    print(a1 == a2)


run(testAffineExprContext)


# CHECK-LABEL: TEST: testAffineExprConstant
@run
def testAffineExprConstant():
    with Context():
        a1 = AffineExpr.get_constant(42)
        # CHECK: 42
        print(a1.value)
        # CHECK: 42
        print(a1)

        a2 = AffineConstantExpr.get(42)
        # CHECK: 42
        print(a2.value)
        # CHECK: 42
        print(a2)

        assert a1 == a2


# CHECK-LABEL: TEST: testAffineExprDim
@run
def testAffineExprDim():
    with Context():
        d1 = AffineExpr.get_dim(1)
        d11 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)

        # CHECK: 1
        print(d1.position)
        # CHECK: d1
        print(d1)

        # CHECK: 2
        print(d2.position)
        # CHECK: d2
        print(d2)

        assert d1 == d11
        assert d1 != d2


# CHECK-LABEL: TEST: testAffineExprSymbol
@run
def testAffineExprSymbol():
    with Context():
        s1 = AffineExpr.get_symbol(1)
        s11 = AffineSymbolExpr.get(1)
        s2 = AffineSymbolExpr.get(2)

        # CHECK: 1
        print(s1.position)
        # CHECK: s1
        print(s1)

        # CHECK: 2
        print(s2.position)
        # CHECK: s2
        print(s2)

        assert s1 == s11
        assert s1 != s2


# CHECK-LABEL: TEST: testAffineAddExpr
@run
def testAffineAddExpr():
    with Context():
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        d12 = AffineExpr.get_add(d1, d2)
        # CHECK: d1 + d2
        print(d12)

        d12op = d1 + d2
        # CHECK: d1 + d2
        print(d12op)

        d1cst_op = d1 + 2
        # CHECK: d1 + 2
        print(d1cst_op)

        d1cst_op2 = 2 + d1
        # CHECK: d1 + 2
        print(d1cst_op2)

        assert d12 == d12op
        assert d12.lhs == d1
        assert d12.rhs == d2


# CHECK-LABEL: TEST: testAffineMulExpr
@run
def testAffineMulExpr():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        expr = AffineExpr.get_mul(d1, c2)
        # CHECK: d1 * 2
        print(expr)

        # CHECK: d1 * 2
        op = d1 * c2
        print(op)

        # CHECK: d1 * 2
        op_cst = d1 * 2
        print(op_cst)

        # CHECK: d1 * 2
        op_cst2 = 2 * d1
        print(op_cst2)

        assert expr == op
        assert expr == op_cst
        assert expr.lhs == d1
        assert expr.rhs == c2


# CHECK-LABEL: TEST: testAffineModExpr
@run
def testAffineModExpr():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        expr = AffineExpr.get_mod(d1, c2)
        # CHECK: d1 mod 2
        print(expr)

        # CHECK: d1 mod 2
        op = d1 % c2
        print(op)

        # CHECK: d1 mod 2
        op_cst = d1 % 2
        print(op_cst)

        # CHECK: 2 mod d1
        print(2 % d1)

        assert expr == op
        assert expr == op_cst
        assert expr.lhs == d1
        assert expr.rhs == c2

        expr2 = AffineExpr.get_mod(c2, d1)
        expr3 = AffineExpr.get_mod(2, d1)
        expr4 = AffineExpr.get_mod(d1, 2)

        # CHECK: 2 mod d1
        print(expr2)
        # CHECK: 2 mod d1
        print(expr3)
        # CHECK: d1 mod 2
        print(expr4)

        assert expr2 == expr3
        assert expr4 == expr


# CHECK-LABEL: TEST: testAffineFloorDivExpr
@run
def testAffineFloorDivExpr():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        expr = AffineExpr.get_floor_div(d1, c2)
        # CHECK: d1 floordiv 2
        print(expr)

        assert expr.lhs == d1
        assert expr.rhs == c2

        expr2 = AffineExpr.get_floor_div(c2, d1)
        expr3 = AffineExpr.get_floor_div(2, d1)
        expr4 = AffineExpr.get_floor_div(d1, 2)

        # CHECK: 2 floordiv d1
        print(expr2)
        # CHECK: 2 floordiv d1
        print(expr3)
        # CHECK: d1 floordiv 2
        print(expr4)

        assert expr2 == expr3
        assert expr4 == expr


# CHECK-LABEL: TEST: testAffineCeilDivExpr
@run
def testAffineCeilDivExpr():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        expr = AffineExpr.get_ceil_div(d1, c2)
        # CHECK: d1 ceildiv 2
        print(expr)

        assert expr.lhs == d1
        assert expr.rhs == c2

        expr2 = AffineExpr.get_ceil_div(c2, d1)
        expr3 = AffineExpr.get_ceil_div(2, d1)
        expr4 = AffineExpr.get_ceil_div(d1, 2)

        # CHECK: 2 ceildiv d1
        print(expr2)
        # CHECK: 2 ceildiv d1
        print(expr3)
        # CHECK: d1 ceildiv 2
        print(expr4)

        assert expr2 == expr3
        assert expr4 == expr


# CHECK-LABEL: TEST: testAffineExprSub
@run
def testAffineExprSub():
    with Context():
        d1 = AffineDimExpr.get(1)
        d2 = AffineDimExpr.get(2)
        expr = d1 - d2
        # CHECK: d1 - d2
        print(expr)

        assert expr.lhs == d1
        rhs = AffineMulExpr(expr.rhs)
        # CHECK: d2
        print(rhs.lhs)
        # CHECK: -1
        print(rhs.rhs)

        # CHECK: d1 - 42
        print(d1 - 42)
        # CHECK: -d1 + 42
        print(42 - d1)

        c42 = AffineConstantExpr.get(42)
        assert d1 - 42 == d1 - c42
        assert 42 - d1 == c42 - d1


# CHECK-LABEL: TEST: testClassHierarchy
@run
def testClassHierarchy():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        add = AffineAddExpr.get(d1, c2)
        mul = AffineMulExpr.get(d1, c2)
        mod = AffineModExpr.get(d1, c2)
        floor_div = AffineFloorDivExpr.get(d1, c2)
        ceil_div = AffineCeilDivExpr.get(d1, c2)

        # CHECK: False
        print(isinstance(d1, AffineBinaryExpr))
        # CHECK: False
        print(isinstance(c2, AffineBinaryExpr))
        # CHECK: True
        print(isinstance(add, AffineBinaryExpr))
        # CHECK: True
        print(isinstance(mul, AffineBinaryExpr))
        # CHECK: True
        print(isinstance(mod, AffineBinaryExpr))
        # CHECK: True
        print(isinstance(floor_div, AffineBinaryExpr))
        # CHECK: True
        print(isinstance(ceil_div, AffineBinaryExpr))

        try:
            AffineBinaryExpr(d1)
        except ValueError as e:
            # CHECK: Cannot cast affine expression to AffineBinaryExpr
            print(e)

        try:
            AffineBinaryExpr(c2)
        except ValueError as e:
            # CHECK: Cannot cast affine expression to AffineBinaryExpr
            print(e)


# CHECK-LABEL: TEST: testIsInstance
@run
def testIsInstance():
    with Context():
        d1 = AffineDimExpr.get(1)
        c2 = AffineConstantExpr.get(2)
        add = AffineAddExpr.get(d1, c2)
        mul = AffineMulExpr.get(d1, c2)

        # CHECK: True
        print(AffineDimExpr.isinstance(d1))
        # CHECK: False
        print(AffineConstantExpr.isinstance(d1))
        # CHECK: True
        print(AffineConstantExpr.isinstance(c2))
        # CHECK: False
        print(AffineMulExpr.isinstance(c2))
        # CHECK: True
        print(AffineAddExpr.isinstance(add))
        # CHECK: False
        print(AffineMulExpr.isinstance(add))
        # CHECK: True
        print(AffineMulExpr.isinstance(mul))
        # CHECK: False
        print(AffineAddExpr.isinstance(mul))


# CHECK-LABEL: TEST: testCompose
@run
def testCompose():
    with Context():
        # d0 + d2.
        expr = AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(2))

        # (d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)
        map1 = AffineAddExpr.get(AffineDimExpr.get(0), AffineSymbolExpr.get(1))
        map2 = AffineAddExpr.get(AffineDimExpr.get(1), AffineSymbolExpr.get(0))
        map3 = AffineAddExpr.get(
            AffineAddExpr.get(AffineDimExpr.get(0), AffineDimExpr.get(1)),
            AffineDimExpr.get(2),
        )
        map = AffineMap.get(3, 2, [map1, map2, map3])

        # CHECK: d0 + s1 + d0 + d1 + d2
        print(expr.compose(map))


# CHECK-LABEL: TEST: testHash
@run
def testHash():
    with Context():
        d0 = AffineDimExpr.get(0)
        s1 = AffineSymbolExpr.get(1)
        assert hash(d0) == hash(AffineDimExpr.get(0))
        assert hash(d0 + s1) == hash(AffineAddExpr.get(d0, s1))

        dictionary = dict()
        dictionary[d0] = 0
        dictionary[s1] = 1
        assert d0 in dictionary
        assert s1 in dictionary


# CHECK-LABEL: TEST: testAffineExprShift
@run
def testAffineExprShift():
    with Context() as ctx:
        dims = [AffineExpr.get_dim(i) for i in range(4)]
        syms = [AffineExpr.get_symbol(i) for i in range(4)]

        assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
        assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)


# CHECK-LABEL: TEST: testAffineExprSimplify
@run
def testAffineExprSimplify():
    with Context() as ctx:
        expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
        assert expr == AffineExpr.simplify_affine_expr(expr, 1, 1)
